{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 多头关注\n",
    ":label:`sec_multihead-attention`\n",
    "\n",
    "实际上，鉴于查询、键和值集相同，我们可能希望我们的模型将来自同一注意机制不同行为的知识结合起来，例如捕获序列内各种范围的依赖关系（例如，短范围与长距离）。因此，允许我们的注意机制共同使用查询、键和值的不同表示子空间可能是有益的。\n",
    "\n",
    "为此，可以使用 $h$ 独立学习的线性投影来转换查询、键和值，而不是执行单一的注意力集中。然后，这些 $h$ 个预计查询、键和值将并行输入注意力集中。最后，$h$ 注意力集中输出被连接在一起，并与另一个学习的线性投影进行转换，以产生最终输出。这种设计被称为 * 多头注意 *，其中 $h$ 注意力池输出中的每个都是 * 头 * :cite:`Vaswani.Shazeer.Parmar.ea.2017`。:numref:`fig_multi-head-attention` 使用完全连接的图层来执行可学习的线性变换，描述了多头注意力。\n",
    "\n",
    "![Multi-head attention, where multiple heads are concatenated then linearly transformed.](../img/multi-head-attention.svg)\n",
    ":label:`fig_multi-head-attention`\n",
    "\n",
    "## 模型\n",
    "\n",
    "在提供多头关注的实施之前，让我们以数学方式将这个模型正式化。给定查询 $\\mathbf{q} \\in \\mathbb{R}^{d_q}$、一个键 $\\mathbf{k} \\in \\mathbb{R}^{d_k}$ 和一个值 $\\mathbf{v} \\in \\mathbb{R}^{d_v}$，每个注意头 $\\mathbf{h}_i$ ($i = 1, \\ldots, h$) 的计算方法为\n",
    "\n",
    "$$\\mathbf{h}_i = f(\\mathbf W_i^{(q)}\\mathbf q, \\mathbf W_i^{(k)}\\mathbf k,\\mathbf W_i^{(v)}\\mathbf v) \\in \\mathbb R^{p_v},$$\n",
    "\n",
    "其中，可学习的参数 $\\mathbf W_i^{(q)}\\in\\mathbb R^{p_q\\times d_q}$、$\\mathbf W_i^{(k)}\\in\\mathbb R^{p_k\\times d_k}$ 和 $\\mathbf W_i^{(v)}\\in\\mathbb R^{p_v\\times d_v}$ 以及 $f$ 是注意力集中，例如 :numref:`sec_attention-scoring-functions` 中的添加剂注意力和扩大点产品注意力。多头注意力输出是另一种线性转换，通过 $h$ 头连接的可学习参数 $\\mathbf W_o\\in\\mathbb R^{p_o\\times h p_v}$：\n",
    "\n",
    "$$\\mathbf W_o \\begin{bmatrix}\\mathbf h_1\\\\\\vdots\\\\\\mathbf h_h\\end{bmatrix} \\in \\mathbb{R}^{p_o}.$$\n",
    "\n",
    "基于这种设计，每个头都可能会关注输入的不同部分。可以表示比简单加权平均值更复杂的函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## 实施\n",
    "\n",
    "在我们的实施过程中，我们为多头关注的每个人选择缩放的点产品注意力。为避免计算成本和参数化成本的显著增长，我们设置了 $p_q = p_k = p_v = p_o / h$。请注意，如果我们将查询、键和值的线性变换的输出数量设置为 $p_q h = p_k h = p_v h = p_o$，则可以并行计算 $h$ 头。在下面的实现中，$p_o$ 是通过参数 `num_hiddens` 指定的。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, key_size, query_size, value_size, num_hiddens,\n",
    "                 num_heads, dropout, bias=False, **kwargs):\n",
    "        super(MultiHeadAttention, self).__init__(**kwargs)\n",
    "        self.num_heads = num_heads\n",
    "        self.attention = d2l.DotProductAttention(dropout)\n",
    "        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)\n",
    "        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)\n",
    "        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)\n",
    "        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)\n",
    "\n",
    "    def forward(self, queries, keys, values, valid_lens):\n",
    "        # Shape of `queries`, `keys`, or `values`:\n",
    "        # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)\n",
    "        # Shape of `valid_lens`:\n",
    "        # (`batch_size`,) or (`batch_size`, no. of queries)\n",
    "        # After transposing, shape of output `queries`, `keys`, or `values`:\n",
    "        # (`batch_size` * `num_heads`, no. of queries or key-value pairs,\n",
    "        # `num_hiddens` / `num_heads`)\n",
    "        queries = transpose_qkv(self.W_q(queries), self.num_heads)\n",
    "        keys = transpose_qkv(self.W_k(keys), self.num_heads)\n",
    "        values = transpose_qkv(self.W_v(values), self.num_heads)\n",
    "\n",
    "        if valid_lens is not None:\n",
    "            # On axis 0, copy the first item (scalar or vector) for\n",
    "            # `num_heads` times, then copy the next item, and so on\n",
    "            valid_lens = torch.repeat_interleave(valid_lens,\n",
    "                                                 repeats=self.num_heads,\n",
    "                                                 dim=0)\n",
    "\n",
    "        # Shape of `output`: (`batch_size` * `num_heads`, no. of queries,\n",
    "        # `num_hiddens` / `num_heads`)\n",
    "        output = self.attention(queries, keys, values, valid_lens)\n",
    "\n",
    "        # Shape of `output_concat`:\n",
    "        # (`batch_size`, no. of queries, `num_hiddens`)\n",
    "        output_concat = transpose_output(output, self.num_heads)\n",
    "        return self.W_o(output_concat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "为了允许多个头的并行计算，上面的 `MultiHeadAttention` 类使用了下面定义的两个移调函数。具体来说，`transpose_output` 函数逆转了 `transpose_qkv` 函数的操作。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "#@save\n",
    "def transpose_qkv(X, num_heads):\n",
    "    # Shape of input `X`:\n",
    "    # (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).\n",
    "    # Shape of output `X`:\n",
    "    # (`batch_size`, no. of queries or key-value pairs, `num_heads`,\n",
    "    # `num_hiddens` / `num_heads`)\n",
    "    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)\n",
    "\n",
    "    # Shape of output `X`:\n",
    "    # (`batch_size`, `num_heads`, no. of queries or key-value pairs,\n",
    "    # `num_hiddens` / `num_heads`)\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "\n",
    "    # Shape of `output`:\n",
    "    # (`batch_size` * `num_heads`, no. of queries or key-value pairs,\n",
    "    # `num_hiddens` / `num_heads`)\n",
    "    return X.reshape(-1, X.shape[2], X.shape[3])\n",
    "\n",
    "#@save\n",
    "def transpose_output(X, num_heads):\n",
    "    \"\"\"Reverse the operation of `transpose_qkv`\"\"\"\n",
    "    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])\n",
    "    X = X.permute(0, 2, 1, 3)\n",
    "    return X.reshape(X.shape[0], X.shape[1], -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "让我们使用键和值相同的玩具示例来测试我们实施的 `MultiHeadAttention` 类。因此，多头注意输出的形状是（`batch_size`、`num_queries`、`num_hiddens`）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MultiHeadAttention(\n",
       "  (attention): DotProductAttention(\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (W_q): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_k): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_v): Linear(in_features=100, out_features=100, bias=False)\n",
       "  (W_o): Linear(in_features=100, out_features=100, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_hiddens, num_heads = 100, 5\n",
    "attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,\n",
    "                               num_hiddens, num_heads, 0.5)\n",
    "attention.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 100])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([\n",
    "    3, 2])\n",
    "X = torch.ones((batch_size, num_queries, num_hiddens))\n",
    "Y = torch.ones((batch_size, num_kvpairs, num_hiddens))\n",
    "attention(X, Y, Y, valid_lens).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 13
   },
   "source": [
    "## 摘要\n",
    "\n",
    "* 多头关注通过查询、键和值的不同表示子空间将同一注意力集中的知识结合起来。\n",
    "* 要并行计算多头多头注意力，需要适当的张量操作。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 在这个实验中，可视化多个头部的注意力重量。\n",
    "1. 假设我们有一个基于多头注意力的训练有素的模型，我们希望修剪最不重要的注意力头以提高预测速度。我们如何设计实验来衡量注意头的重要性？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1635)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}